import torch
from torch.utils.data import Dataset
#import cv2
from PIL import Image
import os#在获取图片的地址的时候需要用
class MyData(Dataset):

    #获取图片的标签
    def __init__(self,root_dir,label_dir):
        # root_dir = "dataset/train"
        # label_dir = "ants"
        self.root_dir = root_dir#这样的作用是将变量self.root_dir变成全局变量,在这个类的不同函数中都可以使用
        self.label_dir = label_dir
        self.path = os.path.join(self.root_dir,self.label_dir)
        self.img_path = os.listdir(self.path)
        # 获取获取到了每一张图片的地址

    #获取每一个图片
    def __getitem__(self, idx):
        img_name = self.img_path[idx]#获取每张照片的名称
        img_item_path = os.path.join(self.root_dir,self.label_dir,img_name)#获取每张图片的路径
        img = Image.open(img_item_path)
        label = self.label_dir
        return img,label

    #获取数据的长度
    def __len__(self):
        return len(self.img_path)


root_dir = "dataset/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir,ants_label_dir)
bees_dataset = MyData(root_dir,bees_label_dir)
train_dataset = ants_dataset + bees_dataset
